
In this notebook we'll be exploring the impact of chunking choices for dask arrays. We'll use an ODC example, but this isn't specific to ODC; it applies to all usage of dask arrays. Chunking choices have a significant impact on performance in several ways, and it is all about tradeoffs. It's not just chunk size that matters either: the relative contiguity of dimensions matters as well.
Thankfully it is possible to re-chunk data for different stages of computation. Whilst re-chunking is an expensive operation the efficiency gains for downstream computation can be very significant and sometimes are simply essential to support the numerical processing required. For example, it is often necessary to have a single chunk on the time dimension for temporal calculations.
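As an illustration, here is a minimal sketch of such a re-chunk. The array is a toy one (the shapes are assumptions for illustration, not the ODC data loaded below), re-chunked so the full time dimension sits in a single chunk:

```python
import dask.array as da

# A toy 3-D array (time, y, x), chunked one step per time slice
arr = da.zeros((12, 512, 512), chunks=(1, 512, 512), dtype="float32")

# Re-chunk so the whole time dimension sits in one chunk (-1 means
# "span the full dimension"), splitting space to keep the chunk size
# similar. This inserts data-shuffling tasks into the task graph.
rechunked = arr.rechunk({0: -1, 1: 256, 2: 256})

print(arr.chunksize)        # (1, 512, 512)
print(rechunked.chunksize)  # (12, 256, 256)
```

Note the total array is unchanged; only the chunk boundaries move, at the cost of extra communication between workers.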
To understand the impact of chunking choices on your code (it is very algorithm dependent) it is essential to understand both the static properties of your arrays (chunk sizes, shapes and task counts) and the dynamic behaviour of their execution.
Dask provides tools for viewing both: the static view when you print out arrays in the notebook, and the dynamic view in the various graphs of the dask dashboard.
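For the static side, a dask array's attributes expose chunk counts and sizes without computing anything. A minimal sketch (a toy array whose shape loosely resembles a single Landsat band stack; the numbers are assumptions for illustration):

```python
import dask.array as da

# Toy stand-in for one band of a large load (lazy, no memory used)
arr = da.zeros((45, 16384, 13312), chunks=(1, 10240, 5120), dtype="int16")

print(f"total size: {arr.nbytes / 2**30:.2f} GiB")
print(f"number of chunks (tasks at minimum): {arr.npartitions}")
print(f"largest chunk shape: {arr.chunksize}")
print(f"per-dimension chunk breakdown: {arr.chunks}")
```

The same attributes appear in the rich array repr that dask prints in a notebook cell.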
The code below will be familiar, it's the same example from previous notebooks (seasonal mean NDVI over a large area). A normalised burn ratio (NBR2) calculation has been added as well to provide some additional load to assist in making the performance differences more noticeable in various graphs. The NBR2 uses two additional bands but is effectively the same type of calculation as the NDVI (a normalised difference ratio).
The primary difference for this example is the calculation (both NDVI and NBR) is performed 4 times, each with a different chunking regime. See the chunk_settings list.
When running this notebook, be sure to have the dask dashboard open and preferably visible as calculations proceed.
There are several sections to pay attention to.
All of these graphs are dynamic and should be interpreted over time.
The dask scheduler itself is also dynamic: as your code executes it stores information about how the tasks are executing and the communication occurring, and adjusts scheduling accordingly. It can take a few minutes for the scheduler to settle into a true pattern. That pattern may also change, particularly in later parts of a computation when work is completing and there are fewer tasks to execute.
Yes, that is a LOT of information. Thankfully you don't necessarily need to learn it all at once. In time reading the information available will become easier as will knowing what to do about it.
Now let's run this notebook, remember to watch the execution in the Dask Dashboard.
Tip: It's likely you will want to repeat the calculation in this notebook several times. Because the results are `persist`ed to the cluster, simply calling it again will result in no execution (none is required because it was `persist`ed). Rather than doing `cluster.shutdown()` and creating a new cluster each time, you can clear the `persist`ed result by performing a `client.restart()`. This will clear out all previous calculations so you can `persist` again. You can do this by either creating a new cell or using a Python Console for this Notebook (right click on the notebook and select New Console for Notebook).
A modest cluster will do... and open the dashboard.
# Initialize the Gateway client
from dask.distributed import Client
from dask_gateway import Gateway

number_of_workers = 5

gateway = Gateway()
clusters = gateway.list_clusters()
if not clusters:
    print('Creating new cluster. Please wait for this to finish.')
    cluster = gateway.new_cluster()
else:
    print(f'An existing cluster was found. Connecting to: {clusters[0].name}')
    cluster = gateway.connect(clusters[0].name)

cluster.scale(number_of_workers)
client = cluster.get_client()
client
An existing cluster was found. Connecting to: easihub.02bfa9279d60429d9c28531382511a40
Client-6a679f93-0460-11ee-9139-06a3e810ea18
| Connection method: Cluster object | Cluster type: dask_gateway.GatewayCluster |
| Dashboard: https://hub.csiro.easi-eo.solutions/services/dask-gateway/clusters/easihub.02bfa9279d60429d9c28531382511a40/status |
Nothing special here
import pyproj
pyproj.set_use_global_context(True)
import git
import sys, os
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from dask.distributed import Client, LocalCluster, wait
import datacube
from datacube.utils import masking
from datacube.utils.aws import configure_s3_access
# EASI defaults
os.environ['USE_PYGEOS'] = '0'
repo = git.Repo('.', search_parent_directories=True).working_tree_dir
if repo not in sys.path: sys.path.append(repo)
from easi_tools import EasiDefaults, notebook_utils
easi = EasiDefaults()
Successfully found configuration for deployment "csiro"
dc = datacube.Datacube()
configure_s3_access(aws_unsigned=False, requester_pays=True, client=client);
# Get the centroid of the coordinates of the default extents
central_lat = sum(easi.latitude)/2
central_lon = sum(easi.longitude)/2
# central_lat = -42.019
# central_lon = 146.615
# Set the buffer to load around the central coordinates
# This is a radial distance, so the bbox spans 2x the buffer in each dimension
buffer = 2
# Compute the bounding box for the study area
study_area_lat = (central_lat - buffer, central_lat + buffer)
study_area_lon = (central_lon - buffer, central_lon + buffer)
# Data product
products = easi.product('landsat')
# Set the date range to load data over
set_time = easi.time
set_time = (set_time[0], parse(set_time[0]) + relativedelta(months=6))
#set_time = ("2021-07-01", "2021-12-31")
# Selected measurement names (used in this notebook)
alias = easi.aliases('landsat')
measurements = [alias[x] for x in ['qa_band', 'red', 'green', 'blue', 'nir', 'swir1', 'swir2']]
# Set the QA band name and mask values
qa_band = alias['qa_band']
qa_mask = easi.qa_mask('landsat')
# Set the resampling method for the bands
resampling = {qa_band: "nearest", "*": "average"}
# Set the coordinate reference system and output resolution
set_crs = easi.crs('landsat') # If defined, else None
set_resolution = easi.resolution('landsat') # If defined, else None
# set_crs = "epsg:3577"
# set_resolution = (-30, 30)
# Set the scene group_by method
group_by = "solar_day"
def calc_ndvi(dataset):
    # Calculate the components that make up the NDVI calculation
    band_diff = dataset[alias['nir']] - dataset[alias['red']]
    band_sum = dataset[alias['nir']] + dataset[alias['red']]
    # Calculate NDVI
    ndvi = band_diff / band_sum
    return ndvi

def calc_nbr2(dataset):
    # Calculate the components that make up the NBR2 calculation
    band_diff = dataset[alias['swir1']] - dataset[alias['swir2']]
    band_sum = dataset[alias['swir1']] + dataset[alias['swir2']]
    # Calculate NBR2
    nbr2 = band_diff / band_sum
    return nbr2

def mask(dataset, bands):
    # Identify pixels that are either "valid", "water" or "snow"
    cloud_free_mask = masking.make_mask(dataset[qa_band], **qa_mask)
    # Apply the mask
    cloud_free = dataset[bands].astype('float32').where(cloud_free_mask)
    return cloud_free

def seasonal_mean(dataset):
    return dataset.resample(time="QS-DEC").mean('time')  # perform the seasonal mean for each quarter
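The "QS-DEC" frequency groups months into quarters anchored on December, i.e. the meteorological seasons (DJF, MAM, JJA, SON), labelled by the quarter start. A minimal pandas sketch (synthetic monthly values, not the satellite data) shows the bin labels this produces:

```python
import pandas as pd

# Six monthly values, Jan-Jun 2020
s = pd.Series(range(6), index=pd.date_range("2020-01-01", periods=6, freq="MS"))

# Quarters anchored on December: Dec-Feb, Mar-May, Jun-Aug, ...
seasonal = s.resample("QS-DEC").mean()
print(seasonal)
# Bins are labelled 2019-12-01, 2020-03-01 and 2020-06-01
```

xarray's `resample(time=...)` uses the same frequency strings, which is why the seasonal means later in this notebook carry quarter-start time labels.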
We have an array of chunk settings to trial.
The chunk_settings are nominally the same size in memory. There are two time:1 settings because 50 doesn't have a clean square root: the first uses the nearest square, the second simply changes the chunks to be rectangles (no one said the spatial dimensions needed to be the same).
Given the chunk size in memory is roughly the same, the cluster is the same, and the calculation is the same, any differences in execution are a result of the different chunk shapes.
chunk_settings = [
    {"time": 50, "x": 1024, "y": 1024},
    {"time": 25, "x": 1 * 1024, "y": 2 * 1024},
    {"time": 1, "x": 7 * 1024, "y": 7 * 1024},
    {"time": 1, "x": 5 * 1024, "y": 10 * 1024},
]
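A quick arithmetic check confirms the four settings describe near-identical chunk sizes in memory (repeating the list so the sketch is self-contained, and assuming the int16 source bands at 2 bytes per pixel):

```python
chunk_settings = [
    {"time": 50, "x": 1024, "y": 1024},
    {"time": 25, "x": 1 * 1024, "y": 2 * 1024},
    {"time": 1, "x": 7 * 1024, "y": 7 * 1024},
    {"time": 1, "x": 5 * 1024, "y": 10 * 1024},
]

for c in chunk_settings:
    pixels = c["time"] * c["x"] * c["y"]
    mib = pixels * 2 / 2**20  # int16 = 2 bytes per pixel
    print(f"{c}: {mib:.0f} MiB per chunk")
# All four come out at 98-100 MiB per chunk
```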
Now we can loop over all our chunk_settings and create all the required delayed task graphs. This will take a moment as the ODC database will be interrogated for all the necessary dataset information.
You will notice the calculation is split up so we can see the interim results - well, the last one at least, given it's a loop and we're overwriting them.
Different stages of computation will produce different data types and calculations and thus chunk and task counts. We may find that an interim result has a terrible chunk size (e.g. int16 data variables become float64 and thus your chunks are now 4x the size, or a dimension is reduced and chunks are too small). It is thus advisable when tuning to make it possible to view these interim stages to see the static impact.
Remember: there is a single task graph executing to provide the final result. There is no need to persist() or compute() the interim results to see their static attributes. In fact, it may be unwise to persist() as this will chew up resources on the cluster if you don't intend on using the results.
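The dtype effect on chunk size is easy to quantify. Using a (1, 10240, 5120) chunk shape for illustration (an assumption matching the time:1 settings here), the same chunk grows with the element size:

```python
import numpy as np

chunk_shape = (1, 10240, 5120)
n = int(np.prod(chunk_shape))  # pixels per chunk

for dtype in ("int16", "float32", "float64"):
    mib = n * np.dtype(dtype).itemsize / 2**20
    print(f"{dtype}: {mib:.0f} MiB per chunk")
# int16: 100 MiB, float32: 200 MiB, float64: 400 MiB
```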
results = []
for chunks in chunk_settings:
    dataset = dc.load(
        product=products,
        x=study_area_lon,
        y=study_area_lat,
        time=set_time,
        measurements=measurements,
        resampling=resampling,
        output_crs=set_crs,
        resolution=set_resolution,
        dask_chunks=chunks,
        group_by=group_by,
    )
    masked_dataset = mask(dataset, [alias[x] for x in ['red', 'nir', 'swir1', 'swir2']])
    ndvi = calc_ndvi(masked_dataset)
    nbr2 = calc_nbr2(masked_dataset)
    seasonal_mean_ndvi = seasonal_mean(ndvi)
    seasonal_mean_nbr2 = seasonal_mean(nbr2)
    seasonal_mean_ndvi.name = 'ndvi'
    seasonal_mean_nbr2.name = 'nbr2'
    results.append([seasonal_mean_ndvi, seasonal_mean_nbr2])
Let's take a look at the vital statistics for the final iteration of the loop. All the calculations are the same and only the chunk parameters vary, so from these we can easily infer the static parameters of the other settings.
print(f"dataset size (GiB) {dataset.nbytes / 2**30:.2f}")
print(f"seasonal_mean_ndvi size (GiB) {seasonal_mean_ndvi.nbytes / 2**30:.2f}")
display(dataset)
dataset size (GiB) 119.90
seasonal_mean_ndvi size (GiB) 2.46
<xarray.Dataset>
Dimensions: (time: 45, y: 16097, x: 13672)
Coordinates:
* time (time) datetime64[ns] 2020-02-01T23:50:22.661832 ... 2020-0...
* y (y) float64 -3.777e+06 -3.778e+06 ... -4.26e+06 -4.26e+06
* x (x) float64 1.151e+06 1.152e+06 ... 1.562e+06 1.562e+06
spatial_ref int32 3577
Data variables:
oa_fmask (time, y, x) uint8 dask.array<chunksize=(1, 10240, 5120), meta=np.ndarray>
nbart_red (time, y, x) int16 dask.array<chunksize=(1, 10240, 5120), meta=np.ndarray>
nbart_green (time, y, x) int16 dask.array<chunksize=(1, 10240, 5120), meta=np.ndarray>
nbart_blue (time, y, x) int16 dask.array<chunksize=(1, 10240, 5120), meta=np.ndarray>
nbart_nir (time, y, x) int16 dask.array<chunksize=(1, 10240, 5120), meta=np.ndarray>
nbart_swir_1 (time, y, x) int16 dask.array<chunksize=(1, 10240, 5120), meta=np.ndarray>
nbart_swir_2 (time, y, x) int16 dask.array<chunksize=(1, 10240, 5120), meta=np.ndarray>
Attributes:
crs: EPSG:3577
    grid_mapping: spatial_ref

So the source dataset is about 120 GiB in size, mostly int16 data. We need to be mindful that our calculation will convert these to floats. The code above does an explicit type conversion to float32, which can fully represent an int16. Without the explicit conversion the arithmetic would promote to float64, doubling the memory usage for no good reason (for this algorithm).
Open the cylinder icon to show the red dask array details. The chunk is about 100 MiB in size. Generally this is a healthy size, though it can be larger, and may need to be smaller depending on the calculation involved and the communication between workers.
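A small numpy sketch shows why the explicit float32 cast matters - true division of integer arrays promotes to float64, while casting one operand first keeps everything at float32:

```python
import numpy as np

a = np.ones((2, 2), dtype="int16")

print((a - a).dtype)                                # int16: subtraction keeps the type
print(((a - a) / (a + a)).dtype)                    # float64: true division promotes
print(((a.astype("float32") - a) / (a + a)).dtype)  # float32 after an explicit cast
```

The same promotion rules apply chunk-by-chunk in dask, so a single missing cast doubles every downstream chunk.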
Now let's look at the results for the NDVI and NBR:
display(results[0][0])
display(results[0][1])
<xarray.DataArray 'ndvi' (time: 3, y: 16097, x: 13672)>
dask.array<stack, shape=(3, 16097, 13672), dtype=float32, chunksize=(1, 1024, 1024), chunktype=numpy.ndarray>
Coordinates:
* y (y) float64 -3.777e+06 -3.778e+06 ... -4.26e+06 -4.26e+06
* x (x) float64 1.151e+06 1.152e+06 ... 1.562e+06 1.562e+06
spatial_ref int32 3577
    * time         (time) datetime64[ns] 2019-12-01 2020-03-01 2020-06-01

<xarray.DataArray 'nbr2' (time: 3, y: 16097, x: 13672)>
dask.array<stack, shape=(3, 16097, 13672), dtype=float32, chunksize=(1, 1024, 1024), chunktype=numpy.ndarray>
Coordinates:
* y (y) float64 -3.777e+06 -3.778e+06 ... -4.26e+06 -4.26e+06
* x (x) float64 1.151e+06 1.152e+06 ... 1.562e+06 1.562e+06
spatial_ref int32 3577
    * time         (time) datetime64[ns] 2019-12-01 2020-03-01 2020-06-01

Notice the result is much smaller in chunk size - 4 MiB. This is due to the seasonal mean. This may have an impact on downstream usage of the result, as the chunks may be too small and result in too many tasks, reducing later processing performance.
Notice also the task count. With both results we're pushing towards 100_000 tasks in the scheduler, depending on task graph optimisation. The scheduler has its own overheads (about 1 ms per active task, plus memory for tracking all tasks, including executed ones, as it keeps the history in case it needs to reproduce results, e.g. if a worker is lost). Again, it is possible to have more than 100_000 tasks and be efficient, depending on your algorithm, but it's something to keep an eye on. We will be below that in this case (especially after optimisation).
Theoretically we could persist all of the results at once, though we would be well above 100_000 tasks if we did.
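You can inspect the (unoptimised) task count of a dask collection before persisting it, via its graph. A toy sketch (the shapes and operations are illustrative, not the real load):

```python
import dask.array as da

# Small-chunked toy array, analogous to the time:45 loads above
arr = da.zeros((45, 4096, 4096), chunks=(1, 1024, 1024), dtype="int16")
result = (arr.astype("float32") / 2).mean(axis=0)

# The raw graph length is a rough proxy for scheduler load;
# graph optimisation usually fuses some of these tasks away.
print(f"chunks: {arr.npartitions}, graph tasks: {len(result.__dask_graph__())}")
```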
More importantly we actually want to see the difference in the dynamics of the execution.
The loop below will persist each result one at a time and wait() for it to be complete.
You should monitor execution in the Dask Dashboard
Look at the various tabs as execution proceeds. You will notice differences in memory per worker, communication between workers (red bars in the Task Stream), white space (idle time), and CPU utilisation (remember to click on the CPU tab to get to this detail).
The Tasks section of the dashboard is particularly useful for comparing the dynamics of all four runs, as the length of the calculations means the snapshot still shows all four blocks of computation at once.
Don't forget, if you want to run the code again use client.restart() to clear out the previous results from the cluster.
client.wait_for_workers(n_workers=number_of_workers)

for i, result in enumerate(results):
    print(f'Chunks: {chunk_settings[i]}')
    client.restart()
    f = client.persist(result)
    %time wait(f)
    client.restart()  # clear the cluster out so each run is cleanly separated
    print()
Chunks: {'time': 50, 'x': 1024, 'y': 1024}
CPU times: user 1.89 s, sys: 155 ms, total: 2.05 s
Wall time: 4min 52s
Chunks: {'time': 25, 'x': 1024, 'y': 2048}
CPU times: user 1.28 s, sys: 60.8 ms, total: 1.34 s
Wall time: 4min 18s
Chunks: {'time': 1, 'x': 7168, 'y': 7168}
CPU times: user 558 ms, sys: 19.1 ms, total: 577 ms
Wall time: 3min 28s
Chunks: {'time': 1, 'x': 5120, 'y': 10240}
CPU times: user 602 ms, sys: 4.87 ms, total: 606 ms
Wall time: 4min 21s
Disconnecting your client is good practice, but the cluster will still be up, so we need to shut it down as well.
client.close()
cluster.shutdown()